{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "37c85ec5",
   "metadata": {},
   "source": [
    "# Mistral Embed\n",
    "\n",
    "The Mistral embedding model, though not widely used, is an interesting candidate for experimenting with Binary Quantization because of its multilingual capabilities in European languages, e.g. English, French, and German. Here, though, we use embeddings created for English text.\n",
    "\n",
    "## Setting up the environment\n",
    "\n",
    "Our dependencies are specified in the `pyproject.toml` file that ships with this notebook. You can install them with Poetry by running the following command in the terminal:\n",
    "\n",
    "```bash\n",
    "poetry install --no-root\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1af4a9ca-bee3-433a-a87a-64bc4c55af75",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import loguru\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Third-party libraries. Credentials (MISTRAL_API_KEY, QDRANT_URL, QDRANT_API_KEY) are read\n",
    "# from the environment via `os.environ` / `os.getenv` after `load_dotenv()` below.\n",
    "from datasets import load_dataset\n",
    "from dotenv import load_dotenv\n",
    "from qdrant_client import QdrantClient, models\n",
    "from qdrant_client.models import PointStruct\n",
    "from tqdm import tqdm\n",
    "\n",
    "load_dotenv()  # take environment variables from .env.\n",
    "\n",
    "logger = loguru.logger\n",
    "logger.add(\"logs.log\", format=\"{time} {level} {message}\", level=\"INFO\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcc0520d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\n",
    "    \"nirantk/dbpedia-entities-google-palm-gemini-embedding-001-100K\",\n",
    "    streaming=False,\n",
    "    split=\"train\",\n",
    ")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "919204f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.remove_columns(column_names=[\"embedding\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9262e94c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(lambda x: {\"combined_text\": f\"{x['title']}\\n{x['text']}\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f417fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90fadc0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralai.client import MistralClient\n",
    "\n",
    "api_key = os.environ[\"MISTRAL_API_KEY\"]\n",
    "client = MistralClient(api_key=api_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a33176",
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_text = dataset[\"combined_text\"]\n",
    "\n",
    "# Embed the corpus in small batches; every API response is kept as-is so the\n",
    "# vectors can be pulled out of `response_objects` afterwards.\n",
    "bs = 10\n",
    "response_objects = []\n",
    "batch_starts = range(0, len(combined_text), bs)\n",
    "for start in tqdm(batch_starts):\n",
    "    this_batch = list(combined_text[start : start + bs])\n",
    "    response_objects.append(\n",
    "        client.embeddings(model=\"mistral-embed\", input=this_batch)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a65e369e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Walk the responses in order and collect the vector of every embedding object\n",
    "# found in each response's `.data` into one flat list.\n",
    "embeddings = [\n",
    "    item.embedding for response in response_objects for item in response.data\n",
    "]\n",
    "\n",
    "dataset = dataset.add_column(\"embedding\", embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa03da6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset.push_to_hub(\"nirantk/dbpedia-entities-mistral-embeddings-100K\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8ed651f",
   "metadata": {},
   "source": [
    "# Use Dataset from Huggingface Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fa42e94c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading readme: 100%|██████████| 432/432 [00:00<00:00, 2.39MB/s]\n",
      "Downloading data: 100%|██████████| 126M/126M [00:30<00:00, 4.15MB/s] \n",
      "Downloading data: 100%|██████████| 126M/126M [00:25<00:00, 5.01MB/s] \n",
      "Generating train split: 100%|██████████| 100000/100000 [00:01<00:00, 71302.72 examples/s]\n"
     ]
    }
   ],
   "source": [
    "dataset = load_dataset(\n",
    "    \"nirantk/dbpedia-entities-mistral-embeddings-100K\",\n",
    "    streaming=False,\n",
    "    split=\"train\",\n",
    ")\n",
    "# Build the Qdrant points in a single pass over the rows: the row index becomes the\n",
    "# point id, the stored embedding is the vector, and text/title travel as payload.\n",
    "# Iterating rows once avoids materialising the whole `embedding` column separately\n",
    "# and an intermediate list of plain dicts.\n",
    "points = [\n",
    "    PointStruct(\n",
    "        id=i,\n",
    "        vector=row[\"embedding\"],\n",
    "        payload={\"text\": row[\"text\"], \"title\": row[\"title\"]},\n",
    "    )\n",
    "    for i, row in enumerate(dataset)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "47e3baef-2746-47ee-a195-1f1ba37b44b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "client = QdrantClient(\n",
    "    url=os.getenv(\"QDRANT_URL\"),\n",
    "    api_key=os.getenv(\"QDRANT_API_KEY\"),\n",
    "    timeout=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cd6a7e3",
   "metadata": {},
   "source": [
    "# Setting up a Collection with Binary Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b1c8eb96",
   "metadata": {},
   "outputs": [],
   "source": [
    "collection_name = \"mistral-embed\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "db5bb6fb-f93e-42c8-b440-d0b771ef8dab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vector size and distance metric for the stored embeddings.\n",
    "vectors_config = models.VectorParams(size=1024, distance=models.Distance.COSINE)\n",
    "\n",
    "# indexing_threshold=0 is used while the data is uploaded; the \"Turn on Indexing\"\n",
    "# cell further down raises it again.\n",
    "optimizers_config = models.OptimizersConfigDiff(\n",
    "    default_segment_number=5,\n",
    "    indexing_threshold=0,\n",
    ")\n",
    "\n",
    "# Binary quantization with always_ram=True.\n",
    "quantization_config = models.BinaryQuantization(\n",
    "    binary=models.BinaryQuantizationConfig(always_ram=True),\n",
    ")\n",
    "\n",
    "client.recreate_collection(\n",
    "    collection_name=collection_name,\n",
    "    vectors_config=vectors_config,\n",
    "    optimizers_config=optimizers_config,\n",
    "    quantization_config=quantization_config,\n",
    "    shard_number=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1568c36b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2024-01-15 16:03:47.633\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m4\u001b[0m - \u001b[1mCollection is empty. Begin upsert.\u001b[0m\n",
      "100%|██████████| 100/100 [04:07<00:00,  2.48s/it]\n"
     ]
    }
   ],
   "source": [
    "collection_info = client.get_collection(collection_name=collection_name)\n",
    "\n",
    "# `points_count` is checked instead of `vectors_count`: the latter is deprecated and may\n",
    "# come back as None, in which case an `== 0` test would silently skip the upload.\n",
    "# A falsy check treats both 0 and None as \"empty\".\n",
    "if not collection_info.points_count:\n",
    "    logger.info(\"Collection is empty. Begin upsert.\")\n",
    "    bs = 1000  # Batch size\n",
    "    for i in tqdm(range(0, len(points), bs)):\n",
    "        slice_points = points[i : i + bs]  # Create a slice of bs points\n",
    "        client.upsert(collection_name=collection_name, points=slice_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d3640b36-adcb-41f5-9e2c-e3ec74b1d82f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100000"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Re-read the collection info after the upload and show how many points it holds\n",
    "# (`points_count` replaces the deprecated `vectors_count`).\n",
    "collection_info = client.get_collection(collection_name=collection_name)\n",
    "collection_info.points_count"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d04ca2b7",
   "metadata": {},
   "source": [
    "### Turn on Indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ef40c18a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The collection was created with indexing_threshold=0; set it to 20000 now that the\n",
    "# upload is done. `optimizers_config` is the supported keyword (`optimizer_config` is\n",
    "# only accepted as a legacy alias).\n",
    "client.update_collection(\n",
    "    collection_name=collection_name,\n",
    "    optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ac42c816-90c8-4a2b-9358-ed9f40283cc6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ScoredPoint(id=32, version=0, score=1.0, payload={'text': 'Sobrassada (Catalan pronunciation: [soβɾəˈsaðə]; Spanish: sobrasada) is a raw, cured sausage from the Balearic Islands made with ground pork, paprika and salt and other spices. Sobrassada, along with botifarró are traditional Balearic sausage meat products prepared in the laborious but festive rites that still mark the autumn and winter pig slaughter known as a matança (in Spanish, matanza) in Majorca and Eivissa.', 'title': 'Sobrassada'}, vector=None, shard_key=None),\n",
       " ScoredPoint(id=98204, version=98, score=0.806023, payload={'text': 'Krakowska (pronounced /krəˈkɒvskə/ krə-KOV-skə) is a type of Polish sausage (kielbasa), usually served as a cold cut. The name derives from the city of Kraków (mediaeval capital of the Polish-Lithuanian Commonwealth till late 16th century). It is made from cuts of lean pork, seasoned with pepper, allspice, coriander, and garlic, packed into large casings, and smoked. English speaking countries often use the German name for the sausage, Krakauer (/krɑːˈkaʊər/ krah-COW-er).', 'title': 'Krakowska'}, vector=None, shard_key=None),\n",
       " ScoredPoint(id=89163, version=89, score=0.7999752, payload={'text': 'Pallars Sobirà (Catalan pronunciation: [pəˈʎaz suβiˈɾa], locally: [paˈʎaz soβiˈɾa]) is a comarca (comparable to a county or shire in much of the English-speaking world) in the mountainous northwest of Catalonia, Spain. The name means \"Upper Pallars\", distinguishing it from the more populous (and less mountainous) Pallars Jussà to its southwest.', 'title': 'Pallars Sobirà'}, vector=None, shard_key=None),\n",
       " ScoredPoint(id=72701, version=72, score=0.79771745, payload={'text': 'Falukorv is a Swedish sausage (korv in Swedish) made of a grated mixture of smoked pork and beef or veal with potato starch flour, onion, salt and mild spices. Falukorv is a cooked sausage, so it can be eaten without any further preparation. Some Swedes use it as a sandwich ingredient, much like ham or turkey.', 'title': 'Falukorv'}, vector=None, shard_key=None),\n",
       " ScoredPoint(id=99566, version=99, score=0.7943326, payload={'text': \"Bresaola (pronounced [breˈzaːola]), or in Italian dialect brisaola, is air-dried, salted beef that has been aged two or three months until it becomes hard and turns a dark red, almost purple colour. It is made from top (inside) round, and is lean and tender, with a sweet, musty smell. It originated in Valtellina, a valley in the Alps of northern Italy's Lombardy region.The word comes from the diminutive of Lombard bresada (braised).\", 'title': 'Bresaola'}, vector=None, shard_key=None)]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: query with the vector of a stored point (index 32) and look at the\n",
    "# five closest hits.\n",
    "search_params = models.SearchParams(\n",
    "    quantization=models.QuantizationSearchParams(\n",
    "        ignore=False,\n",
    "        rescore=False,\n",
    "        oversampling=2.0,\n",
    "    ),\n",
    "    exact=True,\n",
    ")\n",
    "client.search(\n",
    "    collection_name=collection_name,\n",
    "    query_vector=points[32].vector,\n",
    "    search_params=search_params,\n",
    "    limit=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c59e5ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = dataset.train_test_split(test_size=0.1, shuffle=True, seed=37)[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b35e254",
   "metadata": {},
   "outputs": [],
   "source": [
    "oversampling_range = np.arange(1.0, 3.1, 1.0)\n",
    "rescore_range = [True, False]\n",
    "limit_range = [100, 50, 20, 10, 5, 1]\n",
    "\n",
    "\n",
    "def parameterized_search(\n",
    "    point,\n",
    "    oversampling: float,\n",
    "    rescore: bool,\n",
    "    exact: bool,\n",
    "    collection_name: str,\n",
    "    ignore: bool = False,\n",
    "    limit: int = 10,\n",
    "):\n",
    "    \"\"\"Search `collection_name` for the nearest neighbours of `point.vector`.\n",
    "\n",
    "    When `exact` is true, only `exact` is sent as a search parameter, so\n",
    "    `oversampling`, `rescore` and `ignore` do not influence the request.\n",
    "    Otherwise they are forwarded as the quantization search parameters.\n",
    "\n",
    "    Returns the result of `client.search`, requested with `limit` hits.\n",
    "    \"\"\"\n",
    "    if exact:\n",
    "        search_params = models.SearchParams(exact=exact)\n",
    "    else:\n",
    "        search_params = models.SearchParams(\n",
    "            quantization=models.QuantizationSearchParams(\n",
    "                ignore=ignore,\n",
    "                rescore=rescore,\n",
    "                oversampling=oversampling,\n",
    "            ),\n",
    "            exact=exact,\n",
    "        )\n",
    "    return client.search(\n",
    "        collection_name=collection_name,\n",
    "        query_vector=point.vector,\n",
    "        search_params=search_params,\n",
    "        limit=limit,\n",
    "    )\n",
    "\n",
    "\n",
    "with open(\"results.json\", \"w\") as f:\n",
    "    for point in tqdm(points[10:100]):\n",
    "        # The exact search does not depend on oversampling/rescore, so its ids are\n",
    "        # fetched once per limit and reused across the whole grid for this point.\n",
    "        exact_ids_by_limit = {}\n",
    "\n",
    "        ## Running Grid Search\n",
    "        for oversampling in oversampling_range:\n",
    "            for rescore in rescore_range:\n",
    "                for limit in limit_range:\n",
    "                    try:\n",
    "                        if limit not in exact_ids_by_limit:\n",
    "                            exact = parameterized_search(\n",
    "                                point=point,\n",
    "                                oversampling=oversampling,\n",
    "                                rescore=rescore,\n",
    "                                exact=True,\n",
    "                                collection_name=collection_name,\n",
    "                                limit=limit,\n",
    "                            )\n",
    "                            exact_ids_by_limit[limit] = [item.id for item in exact]\n",
    "                        hnsw = parameterized_search(\n",
    "                            point=point,\n",
    "                            oversampling=oversampling,\n",
    "                            rescore=rescore,\n",
    "                            exact=False,\n",
    "                            collection_name=collection_name,\n",
    "                            limit=limit,\n",
    "                        )\n",
    "                    except Exception as e:\n",
    "                        # Log only the id: the full point repr contains the whole vector.\n",
    "                        logger.warning(f\"Skipping point {point.id}: {e}\")\n",
    "                        continue\n",
    "\n",
    "                    exact_ids = exact_ids_by_limit[limit]\n",
    "                    if not exact_ids:\n",
    "                        # Nothing to compare against; also avoids dividing by zero.\n",
    "                        continue\n",
    "                    hnsw_ids = [item.id for item in hnsw]\n",
    "\n",
    "                    # Share of the exact hits that the quantized search also returned.\n",
    "                    accuracy = len(set(exact_ids) & set(hnsw_ids)) / len(exact_ids)\n",
    "\n",
    "                    result = {\n",
    "                        \"query_id\": point.id,\n",
    "                        # Plain float rather than a numpy scalar for JSON and logs.\n",
    "                        \"oversampling\": float(oversampling),\n",
    "                        \"rescore\": rescore,\n",
    "                        \"limit\": limit,\n",
    "                        \"accuracy\": accuracy,\n",
    "                    }\n",
    "                    f.write(json.dumps(result) + \"\\n\")\n",
    "                    logger.info(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "10bf7f19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the grid-search records written to results.json (one JSON object per line).\n",
    "# pandas is already imported as `pd` in the first cell.\n",
    "results = pd.read_json(\"results.json\", lines=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8b99eb24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>oversampling</th>\n",
       "      <th colspan=\"2\" halign=\"left\">1</th>\n",
       "      <th colspan=\"2\" halign=\"left\">2</th>\n",
       "      <th colspan=\"2\" halign=\"left\">3</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rescore</th>\n",
       "      <th>False</th>\n",
       "      <th>True</th>\n",
       "      <th>False</th>\n",
       "      <th>True</th>\n",
       "      <th>False</th>\n",
       "      <th>True</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>limit</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.534444</td>\n",
       "      <td>0.857778</td>\n",
       "      <td>0.534444</td>\n",
       "      <td>0.918889</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.941111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.508333</td>\n",
       "      <td>0.837778</td>\n",
       "      <td>0.508333</td>\n",
       "      <td>0.903889</td>\n",
       "      <td>0.508333</td>\n",
       "      <td>0.927778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50</th>\n",
       "      <td>0.492222</td>\n",
       "      <td>0.834444</td>\n",
       "      <td>0.492222</td>\n",
       "      <td>0.903556</td>\n",
       "      <td>0.492889</td>\n",
       "      <td>0.940889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>100</th>\n",
       "      <td>0.499111</td>\n",
       "      <td>0.845444</td>\n",
       "      <td>0.498556</td>\n",
       "      <td>0.918333</td>\n",
       "      <td>0.497667</td>\n",
       "      <td>0.944556</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "oversampling         1                   2                   3          \n",
       "rescore          False     True      False     True      False     True \n",
       "limit                                                                   \n",
       "10            0.534444  0.857778  0.534444  0.918889  0.533333  0.941111\n",
       "20            0.508333  0.837778  0.508333  0.903889  0.508333  0.927778\n",
       "50            0.492222  0.834444  0.492222  0.903556  0.492889  0.940889\n",
       "100           0.499111  0.845444  0.498556  0.918333  0.497667  0.944556"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy per (oversampling, rescore, limit), leaving out limits 1 and 5.\n",
    "average_accuracy = (\n",
    "    results[~results[\"limit\"].isin([1, 5])]\n",
    "    .groupby([\"oversampling\", \"rescore\", \"limit\"])[\"accuracy\"]\n",
    "    .mean()\n",
    "    .reset_index()\n",
    ")\n",
    "# One row per limit, one column per (oversampling, rescore) pair.\n",
    "acc = average_accuracy.pivot(\n",
    "    index=\"limit\", columns=[\"oversampling\", \"rescore\"], values=\"accuracy\"\n",
    ")\n",
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "17caf1c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|   limit |   (3, True) |\n",
      "|--------:|------------:|\n",
      "|      10 |    0.941111 |\n",
      "|      20 |    0.927778 |\n",
      "|      50 |    0.940889 |\n",
      "|     100 |    0.944556 |\n"
     ]
    }
   ],
   "source": [
    "# Render the oversampling=3.0 / rescore=True column as a Markdown table.\n",
    "selected_column = acc.loc[:, (3.0, True)]\n",
    "markdown_table = selected_column.to_markdown()\n",
    "print(markdown_table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
